Bayesian Inference for the Bazin model

In [1]:
# Core JAX / NumPyro stack for probabilistic modeling
from jax import jit
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dists

# HoloViews with the bokeh backend for interactive plotting
import holoviews as hv
from jax import nn
hv.extension('bokeh')
# Default plot options shared by every figure in this notebook
hv.opts.defaults(hv.opts.Curve(width=500, show_grid=True), 
                 hv.opts.Spread(width=500, show_grid=True),
                 hv.opts.Distribution(width=200, height=200, show_grid=True),
                 hv.opts.Bivariate(cmap='Blues', line_width=0, filled=True, width=350))

We define the Bazin model and generate a light curve. A well-calibrated inference should recover the true parameters.

In [2]:
@jit
def bazin_model(time, t0, A, tfall, trise, eps=1e-8):
    """Evaluate the Bazin supernova light-curve model.

    The flux is an exponential decay with time-scale ``tfall`` modulated
    by a sigmoid rise with time-scale ``trise``, both measured from the
    reference epoch ``t0`` and scaled by the amplitude ``A``.  ``eps``
    guards the divisions against a time-scale of exactly zero.
    """
    elapsed = time - t0
    decay = jnp.exp(-elapsed/(tfall+eps))
    rise = nn.sigmoid(elapsed/(trise + eps))
    return A*decay*rise
    
def bayesian_bazin(time, err=1.0, mag=None):
    """NumPyro model: a Bazin light curve observed with Gaussian noise.

    When ``mag`` is None the 'obs' site is sampled (prior predictive /
    data generation); when ``mag`` is given it is conditioned on as the
    observed fluxes.  Parameters are sampled in an unconstrained
    (log / logit) space and mapped to the physical Bazin parameters.
    """
    # Bazin reparameterizations
    logtfall = numpyro.sample('logtfall', dists.Normal(3.47, 0.41))
    logA = numpyro.sample('logA', dists.Normal(0.397, 0.133))
    t0 = numpyro.sample('t0', dists.Normal(50., 5.))
    # NOTE(review): the site is named 'gamma' although the variable is a logit
    logitgamma = numpyro.sample('gamma', dists.Normal(-3., 1.))
    # Bazin parameters
    tfall = numpyro.deterministic('tfall', jnp.exp(logtfall) + 3.)  # exp + offset keeps tfall > 3
    trise = numpyro.deterministic('trise', tfall*nn.sigmoid(logitgamma))  # enforces 0 < trise < tfall
    A = numpyro.deterministic('A', jnp.exp(logA))  # strictly positive amplitude
    # Deterministic model
    with numpyro.plate('data', size=len(time)):
        f = numpyro.deterministic('f', bazin_model(time, t0, A, tfall, trise))
        y = numpyro.sample('obs', dists.Normal(f, err), obs=mag)
    return f, y, t0, A, tfall, trise
    

# Generate a synthetic light curve by drawing once from the prior predictive
N = 100
key = random.PRNGKey(1234)
key, key_ = random.split(key)

# N irregularly spaced observation epochs in [0, 100]
time = jnp.sort(random.uniform(key_, shape=(N,)))*100
# Dense grid used later to evaluate/plot the inferred model
time_inference = jnp.linspace(0, 100, num=1000)
err = jnp.ones_like(time)*0.1

key, key_ = random.split(key)
# The seed handler lets the model's sample sites draw from the prior
with numpyro.handlers.seed(rng_seed=key_):
    f_true, mag, true_t0, true_A, true_tfall, true_trise = bayesian_bazin(time, err)

observed_mag_plot = hv.ErrorBars((time, mag, err), kdims='time', vdims=['flux', 'fluxerr'], label='data')
true_mag_plot = hv.Curve((time_inference, bazin_model(time_inference, true_t0, true_A, true_tfall, true_trise)), 
                         'time', 'flux', label='underlying model')
observed_mag_plot * true_mag_plot
Out[2]:

Let's start with MCMC

In [3]:
from numpyro.infer import MCMC, NUTS, init_to_median

# Two sequential NUTS chains; initializing at the prior median aids stability
sampler = MCMC(sampler=NUTS(bayesian_bazin, init_strategy=init_to_median), 
               num_chains=2, num_samples=2000, num_warmup=200, 
               chain_method='sequential', jit_model_args=True)

key, key_ = random.split(key)
sampler.run(key_, time, err, mag)
#print(sampler.print_summary(prob=0.9))

# group_by_chain=True -> arrays of shape (num_chains, num_samples)
samples = sampler.get_samples(group_by_chain=True)

def plot_trace(trace):
    """Plot per-chain MCMC traces for the main Bazin parameters.

    Parameters
    ----------
    trace : dict
        Mapping from site name to an array of shape (chains, samples),
        e.g. the output of ``sampler.get_samples(group_by_chain=True)``.

    Returns
    -------
    hv.Layout
        One row per parameter, each overlaying a curve per chain.
    """
    plots = []
    for name in ['A', 'tfall', 'trise', 't0']:
        plot_param = []
        # BUG FIX: iterate over the `trace` argument, not the global `samples`
        for i, chain in enumerate(trace[name]):
            plot_param.append(hv.Curve((chain), vdims=name, label=f'Chain {i}'))
        plots.append(hv.Overlay(plot_param))
    return hv.Layout(plots).cols(1).opts(hv.opts.Curve(width=600, height=150))

plot_trace(samples)
  0%|                         | 0/2200 [00:00<?, ?it/s]/home/phuijse/.conda/envs/info320/lib/python3.10/site-packages/jax/_src/tree_util.py:185: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.
  warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
sample: 100%|â–ˆ| 2200/2200 [00:08<00:00, 274.65it/s, 7 s
sample: 100%|â–ˆ| 2200/2200 [00:02<00:00, 811.56it/s, 7 s
/home/phuijse/.conda/envs/info320/lib/python3.10/site-packages/jax/_src/tree_util.py:185: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.
  warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
Out[3]:

Chains look well converged. Next we explore the posterior of the parameters. The vertical line corresponds to the true value (the one used to generate the light curve).

In [4]:
# Flatten chains: arrays of shape (num_chains * num_samples,)
samples = sampler.get_samples()

# Marginal posteriors; the vertical line marks the true (generating) value
plot_t0 = hv.Distribution(samples['t0'], 't0') * hv.VLine(true_t0.item())
plot_A = hv.Distribution(samples['A'], 'A') * hv.VLine(true_A.item())
plot_tfall = hv.Distribution(samples['tfall'], 'tfall') * hv.VLine(true_tfall.item()) 
plot_trise = hv.Distribution(samples['trise'], 'trise') * hv.VLine(true_trise.item()) 
(plot_t0 + plot_A + plot_tfall + plot_trise)
Out[4]:

Next we explore the predictive posterior:

In [5]:
# Push the MCMC posterior samples through the model on the dense time grid;
# we only keep the noiseless deterministic site 'f'
predictive = numpyro.infer.Predictive(bayesian_bazin, 
                                      posterior_samples=samples, 
                                      return_sites=['f'])
key, key_ = random.split(key)
posterior_samples = predictive(key_, time_inference, 0.001)

# BUG FIX: the quantile levels are 0.01/0.5/0.99, so name them q01/q50/q99
# (they were misleadingly called q5/q95 before)
q01, q50, q99 = jnp.quantile(posterior_samples['f'], jnp.array([0.01, 0.5, 0.99]), axis=0)
# BUG FIX: a symmetric Spread with err = q99-q01 drew a band of TWICE the
# intended width; use the asymmetric (y, neg_err, pos_err) form instead
uncertainty_plot = hv.Spread((time_inference, q50, q50 - q01, q99 - q50),
                             vdims=['flux', 'flux_lo', 'flux_hi']).opts(fill_alpha=0.5)
observed_mag_plot * uncertainty_plot
Out[5]:

Now let's try Variational Inference with a multivariate normal guide

Important details to facilitate convergence:

  • Reduce the initial scale of the guide.
  • Initialize at the median of the prior. The default initializes using a uniform distribution over the full domain.
In [6]:
import optax
from tqdm.notebook import tqdm

from numpyro.infer.autoguide import AutoDelta, AutoNormal, AutoMultivariateNormal


def train_svi(guide, key, lr=1e-2, nepochs=3000):
    """Fit `guide` to the posterior of `bayesian_bazin` with SVI.

    Uses Adam with global-norm gradient clipping and a 10-particle
    Trace_ELBO.  Reads the module-level ``time``, ``err`` and ``mag``
    arrays as the training data.

    Parameters
    ----------
    guide : numpyro autoguide instance
    key : jax PRNG key
    lr : float, learning rate (negated for optax's minimize convention)
    nepochs : int, number of SVI update steps

    Returns
    -------
    (svi, state, loss_evolution) : the SVI object, its final state and
    the per-epoch ELBO loss history.
    """
    clipped_adam = optax.chain(optax.clip_by_global_norm(10.0),  
                               optax.scale_by_adam(),
                               optax.scale(-lr))
    svi = numpyro.infer.SVI(bayesian_bazin, guide, clipped_adam, 
                            loss=numpyro.infer.Trace_ELBO(num_particles=10))
    state = svi.init(key, time, err, mag)
    loss_evolution = []
    jit_update = jit(svi.update)
    for epoch in tqdm(range(nepochs)):
        state, loss = jit_update(state, time, err, mag)
        loss_evolution.append(loss.item())
        # FIX: removed `current_params = svi.get_params(state)` — it was
        # recomputed every epoch and never used (dead work in the hot loop)
    return svi, state, loss_evolution
 
key, key_ = random.split(key)
# Multivariate-normal guide; a small init_scale and prior-median init
# help convergence (see the markdown note above this cell)
guide = AutoMultivariateNormal(bayesian_bazin, init_scale=1e-1, 
                               init_loc_fn=init_to_median(num_samples=10))
#guide = AutoDelta(bayesian_bazin, init_loc_fn=init_to_median(num_samples=10))
svi, state, loss_evolution = train_svi(guide, key_)  

hv.Curve(loss_evolution, 'Epoch', 'Loss')
  0%|          | 0/3000 [00:00<?, ?it/s]
Out[6]:

Let's see the posteriors for VI:

In [7]:
predictive = numpyro.infer.Predictive(bayesian_bazin, 
                                      guide=svi.guide, 
                                      params=svi.get_params(state), 
                                      return_sites=['t0', 'A', 'tfall', 'trise', 'f'],
                                      num_samples=1000)

key, key_ = random.split(key)
posterior_samples = predictive(key_, time_inference, )

plot_t0 = hv.Distribution(posterior_samples['t0'], 't0') * hv.VLine(true_t0.item())
plot_A = hv.Distribution(posterior_samples['A'], 'A') * hv.VLine(true_A.item())
plot_tfall = hv.Distribution(posterior_samples['tfall'], 'tfall') * hv.VLine(true_tfall.item()) 
plot_trise = hv.Distribution(posterior_samples['trise'], 'trise') * hv.VLine(true_trise.item()) 
(plot_t0 + plot_A + plot_tfall + plot_trise)
Out[7]:
In [8]:
# BUG FIX: levels are 0.01/0.5/0.99 — name the variables accordingly
# (they were misleadingly called q5/q95 before)
q01, q50, q99 = jnp.quantile(posterior_samples['f'], jnp.array([0.01, 0.5, 0.99]), axis=0)
# BUG FIX: a symmetric Spread with err = q99-q01 drew a band of TWICE the
# intended width; use the asymmetric (y, neg_err, pos_err) form instead
uncertainty_plot = hv.Spread((time_inference, q50, q50 - q01, q99 - q50),
                             vdims=['flux', 'flux_lo', 'flux_hi']).opts(fill_alpha=0.5)
observed_mag_plot * uncertainty_plot
Out[8]: